{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "348f45cd",
   "metadata": {},
   "source": [
    "# 习题\n",
    "## 习题9.1\n",
    "![image.png](./images/exercise1.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c66e055d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "模型参数的初值：\n",
      "prob_A=0.46, prob_B=0.55, prob_C=0.67\n",
      "EM算法训练过程：\n",
      "第1次：prob_A=0.4619, prob_B=0.5346, prob_C=0.6561\n",
      "第2次：prob_A=0.4619, prob_B=0.5346, prob_C=0.6561\n",
      "模型参数的极大似然估计：\n",
      "prob_A=0.4619, prob_B=0.5346, prob_C=0.6561\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "class ThreeCoinEM:\n",
    "    \"\"\"EM estimation of the three-coin model (coins A, B and C).\"\"\"\n",
    "\n",
    "    def __init__(self, prob, tol=1e-6, max_iter=1000):\n",
    "        \"\"\"\n",
    "        Initialise the model parameters.\n",
    "        :param prob: initial values (prob_A, prob_B, prob_C): probability of heads for coins A, B and C\n",
    "        :param tol: convergence threshold on the summed absolute change of the three parameters\n",
    "        :param max_iter: maximum number of EM iterations\n",
    "        \"\"\"\n",
    "        self.prob_A, self.prob_B, self.prob_C = prob\n",
    "        self.tol = tol\n",
    "        self.max_iter = max_iter\n",
    "        # Observation sequence; set by fit() so that calc_mu does not read a global variable\n",
    "        self._data = None\n",
    "\n",
    "    def calc_mu(self, j, data=None):\n",
    "        \"\"\"\n",
    "        (E-step) Compute mu_j.\n",
    "        :param j: index of the observation y_j\n",
    "        :param data: observation sequence; defaults to the one passed to fit()\n",
    "        :return: probability, under the current parameters, that y_j came from coin B\n",
    "        \"\"\"\n",
    "        if data is None:\n",
    "            data = self._data\n",
    "        y_j = data[j]\n",
    "        # Coin A shows heads, so y_j is produced by coin B\n",
    "        pro_1 = self.prob_A * \\\n",
    "            math.pow(self.prob_B, y_j) * \\\n",
    "            math.pow((1 - self.prob_B), 1 - y_j)\n",
    "        # Coin A shows tails, so y_j is produced by coin C\n",
    "        pro_2 = (1 - self.prob_A) * \\\n",
    "            math.pow(self.prob_C, y_j) * \\\n",
    "            math.pow((1 - self.prob_C), 1 - y_j)\n",
    "        return pro_1 / (pro_1 + pro_2)\n",
    "\n",
    "    def fit(self, data):\n",
    "        \"\"\"\n",
    "        Run EM on an observation sequence until the parameters converge or max_iter is reached.\n",
    "        :param data: non-empty sequence of 0/1 observations (1 = heads, 0 = tails)\n",
    "        :raises ValueError: if data is empty\n",
    "        \"\"\"\n",
    "        count = len(data)\n",
    "        if count == 0:\n",
    "            raise ValueError(\"data must not be empty\")\n",
    "        # Remember the data so calc_mu can also be called without passing it explicitly\n",
    "        self._data = data\n",
    "        print(\"模型参数的初值：\")\n",
    "        print(\"prob_A={}, prob_B={}, prob_C={}\".format(\n",
    "            self.prob_A, self.prob_B, self.prob_C))\n",
    "        print(\"EM算法训练过程：\")\n",
    "        for i in range(self.max_iter):\n",
    "            # (E-step) probability that each y_j came from coin B under the current parameters\n",
    "            _mu = [self.calc_mu(j, data) for j in range(count)]\n",
    "            # (M-step) new parameter estimates\n",
    "            prob_A = sum(_mu) / count\n",
    "            prob_B = sum(mu * y for mu, y in zip(_mu, data)) / sum(_mu)\n",
    "            prob_C = sum((1 - mu) * y for mu, y in zip(_mu, data)) \\\n",
    "                / sum(1 - mu for mu in _mu)\n",
    "            print('第{}次：prob_A={:.4f}, prob_B={:.4f}, prob_C={:.4f}'.format(\n",
    "                i + 1, prob_A, prob_B, prob_C))\n",
    "            # Summed absolute change of the three parameters in this iteration\n",
    "            error = abs(self.prob_A - prob_A) + \\\n",
    "                abs(self.prob_B - prob_B) + abs(self.prob_C - prob_C)\n",
    "            self.prob_A, self.prob_B, self.prob_C = prob_A, prob_B, prob_C\n",
    "            # Converged: report the estimate and stop\n",
    "            if error < self.tol:\n",
    "                print(\"模型参数的极大似然估计：\")\n",
    "                print(\"prob_A={:.4f}, prob_B={:.4f}, prob_C={:.4f}\".format(self.prob_A, self.prob_B,\n",
    "                                                                           self.prob_C))\n",
    "                break\n",
    "        else:\n",
    "            # max_iter exhausted without meeting tol: still report the last estimate\n",
    "            print(\"未在{}次迭代内收敛，当前估计：\".format(self.max_iter))\n",
    "            print(\"prob_A={:.4f}, prob_B={:.4f}, prob_C={:.4f}\".format(\n",
    "                self.prob_A, self.prob_B, self.prob_C))\n",
    "\n",
    "                \n",
    "# Observations y_1..y_10 of the three-coin experiment (1 = heads, 0 = tails)\n",
    "data = [1, 1, 0, 1, 0, 0, 1, 0, 1, 1]\n",
    "\n",
    "# Build the EM model from the initial values (prob_A, prob_B, prob_C) and train it on data\n",
    "init_prob = [0.46, 0.55, 0.67]\n",
    "em = ThreeCoinEM(init_prob, tol=1e-5, max_iter=100)\n",
    "em.fit(data)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "f8faafca",
   "metadata": {},
   "source": [
    "## 习题9.3\n",
    "![image.png](./images/exercise3.png)\n",
    "### 采用sklearn的GaussianMixture计算6个参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "90e57018",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "分类结果：labels = [0 0 1 1 1 1 1 1 1 1 1 1 1 1 1]\n",
      "\n",
      "两个分量高斯混合模型的6个参数如下：\n",
      "means = [[-57.51107027  32.98489643]]\n",
      "covariances = [[ 90.24987882 429.45764867]]\n",
      "weights =  [[0.13317238 0.86682762]]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjQAAAHBCAYAAAB+PCE0AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAx50lEQVR4nO3de3QUZZ7G8acvbGwuaQIJBkgQhqCCRgJ6UJOwRFdwopFB8AgKIjIKeNl1cWWQUTYQWVFggxcus+qIM86gY0RlNiOgGSSuLatoIxcHFUUuGUmIsEkjY0LSqf2jTxojYSRRuuol3885dcp6u1L1qxqn6vHtt6tclmVZAgAAMJjb7gIAAAB+KAINAAAwHoEGAAAYj0ADAACMR6ABAADGI9AAAADjEWgAAIDxCDQAAMB4BBoAAGA8Ag1wij399NPq06ePfD6frr76an355Zd2l3RCvXv31rPPPhvz/T777LNyuVwaOXJktG3RokVyuVyaNGlSk3U3bNggl8sV4wpjZ9KkSXK5XCosLIy25eXlyeVy/aD/bXJycjRnzpwW/92cOXOUk5PT6v0CsUKgAU6hV155RbfddpsmTZqkVatWqaysTBMmTLC7rBN65ZVXdPXVV9u2/y1btkT/eevWrc2uM3jwYG3cuLHV+9i9e3erbuyxdjLnAsAxXrsLAE5nc+bM0YgRI5Sfny9J8vl8uvzyy7V161ZdcMEFNld3vEGDBtm6/71796qqqkqdO3duckP/tvj4eF1yySWt3sfu3bs1d+5cx4eaxhBTVVWlffv22VwN4Hz00ACnyJdffqmtW7dqxIgR0bbBgwdLkrZv325XWY511llnqXv37tq6davq6ur08ccf6+KLL7a7LFtceuml+stf/qL6+npt2bJFZ555ps466yy7ywIcjUADnCI7duyQJPXp0yfa5vf7tWPHDuXm5kbbFi9erJ/85Cdq3769MjIyVFJSEv2sufELzz77rHr37h1dPnTokG666SZ169ZN8fHxuuqqq7R79+4mf/OrX/1K/fr1k8/nU//+/fXiiy82W/OJxtAcOXJE06ZN05lnnim/36+f/vSn2rVrV/TzxvEZzzzzjHr37q34+HjdeOONqqmp+b7T1ER6erq2bt2qHTt2yOVyqV+/fsetc6IxNBs3bpTH49Grr74qSaqvr1f//v11/fXXS4qcS5fLpcsuu0yS5HK5jhuj09w4k0mTJjVZZ9KkSRo1apQqKio0ceJEJSUl6YMPPoh+3tDQoIKCAqWmpp7wf4/vk5aWJo/Ho48//lhbtmxRenp6k8/r6+v1y1/+UsnJyerQoYOuvfba43pxFixYoOTkZPn9fs2cOfO4fezZs0fXXnutOnXqpNTUVM2ePVv19fUtqhNwEgINcIpUVlZKinxFIknhcFj19fVKS0tT586dJUkrV67UPffco2nTpum1117TpZdequuuu06HDx8+6f3cc889Kikp0fLly/W73/1OX331laZOnRr9fMOGDbr99tt1ww03qLi4WLm5uRo/frz27t170vuYPn26XnjhBS1ZskQvvfSSjhw5ottuu63JOqtXr9ZDDz2khQsXasGCBXrxxRf11FNPnfQ+pEig2bJli7Zu3aoBAwbI4/Gc9N9eeumlmjZtmmbMmKGjR49q+fLl2r9/vx577DFJ0q233qqNGzdq6dKlkiIBaOPGjZo9e3aLapSkr7/+WkOHDlVNTY3y8/PVq1ev6GcPPvigFi1apPz8fL300kuqrq7WFVdc0aKw4Ha7NWDAgOi5+G6gmTp1qpYuXar8/Hw9//zz2rVrl4YNG6ZQKCRJeuGFF3Tffffptttu0wsvvKB3331XgUAg+vdHjx7ViBEjVFlZqaKiIhUUFGjx4sV68MEHW3wuAKdgDA1witTV1UmK3Jwk6bzzztMnn3wiSVqxYoUmTZqknj176ne/+53Gjx8vSerSpYt+9atfaceOHRoyZMhJ7Wf3
7t0aNGiQxowZE93Pt3sEGv/5rrvuUrdu3TR06FBdfvnlat++/Ukfy2WXXaYbbrgh2ruxY8cO3XfffU3W+eSTT/Txxx9Hb+7//d//3eLBrOnp6Vq2bJk6d+7cqjFG8+fP1+rVqzVv3jwtW7ZMDz/8sLp37y5JSklJUUpKSrTX6IeMw/nzn/+sOXPmRMdGNaqpqdHDDz+sBx98MNqrk5SUpMGDB+vtt99u0a+FGnurtmzZojvvvFMvv/yyJOmLL77QihUr9NRTT+nnP/+5pMjYp379+mnFihW6++679eijjyo3NzcaUC688MImoev555/XF198ofXr1+vMM8+UFBmE/Jvf/EZz585t7WkBbEUPDXCKdOzYUVLkv+YladWqVdq0aVP0BitJw4YNU1JSkv75n/9ZF154oS666CJJ0t/+9rcTbrehoaHJ8h133KGSkhJlZ2dr5syZ+uyzz3T55ZdHPx85cqT69OmjCy+8ULfeeqt++9vfKjMzU4mJiSd9LNddd5327t2riRMnql+/fvrXf/1XffPNN03WGTVqVJObZlJSUjTUnaz09HRt375dH3744XG9EicjPj5eTzzxhB588EH169evSU9Va333fEtSr1699Mtf/vK49p07d6qmpkYzZsxQu3bt1K5du+i4qZ07d7Zov+np6QoGg/roo4+anIv3339flmXpn/7pn6JtqampOvvss7Vp0yZJ0qefftpk/FG3bt107rnnRpcbxymlpKRE63zssce0Z88eHT16tEV1Ak5BoAFOkb59+0qK/Be1FOk5ueiii1RbWxtd55e//KVGjx4tj8ejGTNmnNRYi++Olbj++uu1c+dOTZo0SRUVFbr22ms1duzY6OddunTRRx99pGXLlqlr165asGCBzjnnnJP+yikcDuvyyy/XnDlz1LdvXz366KNavXr1CY/3hxgwYIBqa2tVUlLSqkAjSWVlZZKkgwcPtngMT3Oa+4XRoEGD1K5duxP+zVNPPaVNmzY1mX72s5+1aL/p6ekqKSlRbW2tBgwYEG23LKvZ9V0uV/SzcDh83Nd1311OSUk5rsZNmzbJ66XjHmYi0ACnyPnnn6/k5OToIFUp8rPkQ4cORZeXL1+ue+65R48++qjGjRsXHQPRyOv16siRI9HlhoYGvfTSS03WmTFjhkKhkG699VY9++yzKiwsVFFRkf7v//5PkvTSSy/p97//va655ho98sgjCgaDqqqq0qpVq07qOLZv3663335bv/71r5Wfn6+rr746Ghq+rSXjXU7kjDPOUFpamiS16iun3bt3a9asWVq6dKnq6uqO+0qocR+Sjuthko4/3/v27WvRM2/S0tIUFxenb775RhdddJEuuugipaen64knnmjxL9saA13jYO5GF154oSRp/fr10baysjJ98skn0R6+tLS0aG+NFAl3jYPUpci/mwcOHFDv3r2jdYZCIT3++OMMDIaxiOLAKeJ2u1VQUKApU6Zo6tSpGjVqlBYuXNhkna5du+r1119XTk6OPvnkExUUFEhS9KZywQUXaO7cudq4caPOP/98zZgxQ5WVldGbsqToGIuZM2fqjDPO0IsvvqguXbrI7/dLigwAveeeeyRJZ599tt566y3V19frJz/5yUkdR0JCglwul1auXCmXy6WSkhL953/+Z7TOH/u/6NPT03Xw4EElJye3+G9vu+02DR48WLfffrtSUlI0evRoXX/99dEbvST1799fHTt21COPPKJ//Md/1LZt2zR27FglJyfrggsu0CuvvKJ/+7d/09GjRzV27NjoAO6T4fP59Itf/EL5+fnyeDw699xz9fTTT6u4uLjFY1OSk5OVmJh4XE9V3759NXHixGiNPXr0UH5+vpKTkzV58mRJkfFSt956q+bOnatLLrlECxcubNIzeOONN+qhhx7Stddeq5kzZ+ro0aO69957df755+sf/uEfWlQn4BT00ACn0G233aYlS5aouLhYEydOVN++fTVw4MDo57/5zW90
9OhRXXXVVXriiSdUWFiopKQk/c///I8k6ZprrtGUKVOUm5urvn37yufzRUNPoxUrVqhHjx6aPHmyrrnmGn3zzTf64x//GB2MfOONN2r27NlasGCBrrzySj311FP6j//4j5P+CqRXr176r//6L73xxhvKzc3V22+/rSeffFKSonX+mNLT01v1ddPTTz+t0tJSLV++PPoahauuukqTJ09uMpbH7/fr+eef18qVK3XllVdq2bJl0c9mzZqlvn37Ki0tTTk5Ofr5z3+un/70py2qIz8/X//yL/+ihx56SHl5edq7d6/WrVvX5Kf2J+tE5+Lpp5/WlClT9O///u8aN26cevXqpdLS0miInTRpkhYsWKAVK1ZozJgx6tevn7KysqJ/HxcXp9dff11dunTRuHHjNHXqVF1xxRV67rnnWlwj4BQu60RfyAIAABiCHhoAAGA8Ag0AADAegQYAABiPQAMAAIxHoAEAAMYj0AAAAOO1iQfrNTQ06Msvv1SnTp3kcrnsLgcAAJwEy7J0+PBh9ejRI/psrb+3sqNs3brVuuSSS6yOHTtaI0aMsPbs2WNZlmV98MEH1sCBA60zzjjDGj58uFVRUXHS29y3b58liYmJiYmJicnAad++fd97r3fcg/XS0tJ0yy236JZbbtHcuXP1+eef6/XXX1ffvn01fvx4TZs2TXfeeac6duyo3//+9ye1zerqanXu3Fn79u1TfHz8KT4CAADwYwiFQkpNTVVVVVX0Sdgn4qhAU1lZqW7dumn//v1KTk7Wxo0bdcUVV6i4uFijRo3SwYMH5fV6FQwGlZ2drcrKSnXo0OF7txsKheT3+1VdXU2gAQDAEC25fztqUHBCQoJSUlK0bt06SdLatWuVkZGhQCCgIUOGRF+Cl5GRoXA4rGAwaGe5AADAIRw1KNjr9erFF1/UZZddpilTpqhjx4567733tHjxYiUmJkbXc7vdSkhIUEVFRbPbqa2tbfJm2VAodMprBwAA9nFUD80333yjiRMnau7cuQoGg5owYYImT54sSfruN2OWZZ3wF0vz58+X3++PTqmpqae8dgAAYB9HBZrXX39dR48e1cyZM3Xeeedp4cKFeuedd5SUlKTKysroeuFwWFVVVUpOTm52O7NmzVJ1dXV02rdvX6wOAQAA2MBRgcbj8SgcDkeXLctSQ0ODLrvsMm3atEn19fWSpM2bN8vr9WrQoEHNbicuLk7x8fFNJgAAcPpyVKC5+OKLFQqFtHjxYpWVlen+++9XamqqLr74YiUlJSk/P19lZWUqKCjQ6NGj1b59e7tLBgAADuCoQJOUlKSioiI988wzOuecc/T222/r5ZdfVlxcnIqKilRcXKy0tDTV1NSosLDQ7nIBAIBDOOo5NKcKz6EBAMA8xj6HBgAAoDUINAAAwHgEGgAAYDwCDQAAaLVAQMrNlVJSIvNAwJ46HPXqAwAAYI5AQMrJkSxLCoel8nKppETasEHKyoptLfTQAACAVpk371iYkSJzy4q0xxqBBgAAtMq2bcfCTKNwONIeawQaAADQKunpksfTtM3jibTHGoEGAAC0ygMPSC7XsVDj8USWZ8+OfS0EGgAA0CpZWZEBwMOHSz17RualpVJmZuxr4VdOAACg1bKypDVr7K6CHhoAAHAaINAAAADjEWgAAIDxCDQAAMB4BBoAAGA8Ag0AADAegQYAABiPQAMAAIxHoAEAAMYj0AAAAOMRaAAAgPEINAAAwHgEGgAAYDwCDQAAMB6BBgAAGI9AAwAAjEegAQAAxiPQAAAA4xFoAACA8Qg0AADAeAQaAABgPAINAAAwHoEGAAAYj0ADAACMR6ABAADGI9AAAADjOS7Q1NXVaerUqerUqZMGDBig9957T5IUDAaVkZEhn8+nESNG6MCBAzZXCgAAnMJxgWbRokXavXu3Nm/erHHjxmnChAlqaGjQmDFjlJeXp507d8rn82n69Ol2lwoAgO0CASk3V0pJicwDAbsrsofLsizL7iK+
LS0tTatWrdLAgQP19ddfa+3atUpISNDo0aN18OBBeb1eBYNBZWdnq7KyUh06dPjebYZCIfn9flVXVys+Pj4GRwEAwKkXCEg5OZJlSeGw5PFILpe0YYOUlWV3dT9cS+7fjuqhKS8v165du1RaWiq/369hw4Zp4MCB2rhxo4YMGSKv1ytJysjIUDgcVjAYbHY7tbW1CoVCTSYAAE438+YdCzNSZG5Zkfa2xlGBZv/+/XK73Xr33Xe1ZcsW9e/fX9OmTVN5ebkSExOj67ndbiUkJKiioqLZ7cyfP19+vz86paamxuoQAACImW3bjoWZRuFwpL2tcVSgOXLkiMLhsPLz89W7d2/dddddevPNN9XQ0KDvfjNmWZZcLlez25k1a5aqq6uj0759+2JRPgAAMZWeHvma6ds8nkh7W+OoQOP3+yVJXbp0kSR17dpVlmWpZ8+eqqysjK4XDodVVVWl5OTkZrcTFxen+Pj4JhMAAKebBx6IjJlpDDWNY2hmz7a3Ljs4KtCkpaWpXbt2+vTTTyVJFRUV8ng8Gjp0qDZt2qT6+npJ0ubNm+X1ejVo0CA7ywUAwFZZWZEBwMOHSz17RualpVJmpt2VxZ7jfuV03XXXKRQKafny5brvvvtUU1Oj1atXq1+/fho3bpxuv/123XHHHfL7/XruuedOapv8ygkAAPMY+ysnSVq2bJksy1J6eroqKiq0ZMkSud1uFRUVqbi4WGlpaaqpqVFhYaHdpQIAAIdwXA/NqUAPDQAA5jG6hwYAAFPwlF7n8NpdAAAAJvruU3rLy6WSktPnKb2moYcGAIBW4Cm9zkKgAQCgFXhKr7MQaAAAaAWe0ussBBoAAFqBp/Q6C4EGAIBW4Cm9zsKvnAAAaKWsLGnNGrurgEQPDQAAOA0QaAAAgPEINAAAwHgEGgAAYDwCDQAAMB6BBgAAGI9AAwAAjEegAQAAxiPQAACMFQhIublSSkpkHgjYXRHswpOCAQBGCgSknBzJsiJvuS4vl0pKIq8jyMqyuzrEGj00AAAjzZt3LMxIkbllRdrR9hBoAABG2rbtWJhpFA5H2tH2EGgAAEZKT5c8nqZtHk+kHW0PgQYAYKQHHpBcrmOhxuOJLM+ebW9dsAeBBgBgpKysyADg4cOlnj0j89JSKTPT7spgB37lBAAwVlaWtGaN3VXACeihAQAAxiPQAAAA4xFoAACA8Qg0AADAeAQaAABgPAINAAAwHoEGAAAYj0ADAACMR6ABAADGI9AAAADjEWgAAIDxCDQAgFYLBKTcXCklJTIPBOyuCG2VYwPNW2+9JZfLpQ0bNkiSgsGgMjIy5PP5NGLECB04cMDeAgGgjQsEpJwc6Y03pL/+NTLPySHUwB6ODDR1dXW64447ossNDQ0aM2aM8vLytHPnTvl8Pk2fPt3GCgEA8+ZJliWFw5HlcDiyPG+evXWhbfLaXUBzFi9erG7duqmsrEySVFpaqkOHDmnOnDnyer3Kz89Xdna2jhw5og4dOthcLQC0Tdu2HQszjcLhSDsQa47roSkrK9PDDz+spUuXRtsCgYCGDBkirzeSvzIyMhQOhxUMBu0qEwDavPR0yeNp2ubxRNqBWHNcoLn77rs1depU9e/fP9pWXl6uxMTE6LLb7VZCQoIqKiqa3UZtba1CoVCTCQDw43rgAcnlOhZqPJ7I8uzZ9taFtslRgWbt2rX64IMPNLuZ/zdYlnXcssvlanY78+fPl9/vj06pqamnpF4AaMuysqQNG6Thw6WePSPz0lIpM9PuytAWOSrQ/OEPf9D+/fvVo0cPde7cWdXV1crLy1P37t1VWVkZXS8cDquqqkrJycnNbmfWrFmqrq6OTvv27YvVIQBAm5KVJa1ZI5WVReaEGdjFZX2368NGX331lb7++uvo8gUXXKAnn3xSPXr0UF5eng4dOiSv16v3339fw4YNU2Vlpdq3b/+92w2FQvL7/aqurlZ8fPypPAQAAPAjacn921E9NImJierdu3d0crvd
Sk5OVnZ2tpKSkpSfn6+ysjIVFBRo9OjRJxVmAADA6c9RgeZE3G63ioqKVFxcrLS0NNXU1KiwsNDusgAAgEM48jk0jaqqqqL/PHjwYG3ZssW+YgAAgGMZ0UMDAADw9xBoAACA8Qg0AADAeAQaAABgPAINAAAwHoEGAAAYj0ADAACMR6ABAADGI9AAAADjEWgAAIDxCDQAAMB4BBoAAGA8Ag0AADAegQYAABiPQAMAAIxHoAEAAMYj0AAAAOMRaADAYIGAlJsrpaRE5oGA3RUB9vDaXQAAoHUCASknR7IsKRyWysulkhJpwwYpK8vu6oDYoocGAFrJ7t6RefOOhRkpMresSDvQ1tBDAwCt4ITekW3bjoWZRuFwpB1oa+ihAYBWcELvSHq65PE0bfN4Iu1AW0OgAYBWcELvyAMPSC7XsVDj8USWZ8+OXQ2AUxBoAKAVnNA7kpUV+Ypr+HCpZ8/IvLRUysyMXQ2AU7gsy7LsLuJUC4VC8vv9qq6uVnx8vN3lADgNfHcMTWPvCIEC+PG05P5NDw0AtAK9I4Cz8CsnAGilrCxpzRq7qwAg0UMDAABOAwQaAABgPAINAGPZ/aReAM7BGBoARnLCk3oBOAc9NABazc4eEic8qReAc9BDA6BV7O4hccKTegE4Bz00AFrF7h4SJzypF4BzEGgAtIrdPSS8xwjAtxFoALSK3T0kPKkXwLc5LtDs2rVLw4YNU6dOnZSTk6M9e/ZIkoLBoDIyMuTz+TRixAgdOHDA5kqBts0JPSSNT+otK4vMCTNA2+W4QDNlyhT16tVL27dvV9euXXXnnXeqoaFBY8aMUV5ennbu3Cmfz6fp06fbXSrQptFDAsBJHPW27aNHj+qMM87Q9u3bNWDAAL322mu64YYb9Oqrr2rUqFE6ePCgvF6vgsGgsrOzVVlZqQ4dOnzvdnnbNgAA5jH2bdt1dXVasGCB+vTpI0k6ePCgfD6fAoGAhgwZIq838ivzjIwMhcNhBYNBO8sFAAAO4ajn0HTo0EH33nuvpEi4efzxx3XTTTepvLxciYmJ0fXcbrcSEhJUUVHR7HZqa2tVW1sbXQ6FQqe2cAAAYCtH9dA0qq+v1/jx4+V2u1VQUCBJ+u43Y5ZlyeVyNfv38+fPl9/vj06pqamnvGYAAGAfxwWahoYGjRs3Tp999pnWrFkjn8+n7t27q7KyMrpOOBxWVVWVkpOTm93GrFmzVF1dHZ327dsXq/IBAIANHBdoCgoK9Nlnn2n9+vXq0qWLJGno0KHatGmT6uvrJUmbN2+W1+vVoEGDmt1GXFyc4uPjm0wAAOD05ahAU15ersWLF2v58uWSpKqqKlVVVSk7O1tJSUnKz89XWVmZCgoKNHr0aLVv397migEAgBM4KtCsW7dOoVBImZmZSkhIiE579+5VUVGRiouLlZaWppqaGhUWFtpdLgAAcAhHPYfmVOE5NAAAmMfY59AAAAC0BoEGMFggIOXmSikpkXkgYHdFAGAPRz1YD8DJCwSknBzJsqRwWCovl0pKIu9XysqyuzoAiC16aABDzZt3LMxIkbllRdoBoK0h0ACG2rbtWJhpFA5H2gGgrSHQAD+AnWNY0tMlj6dpm8cTaQeAtoZAA7RS4xiWN96Q/vrXyDwnJ3ah5oEHJJfrWKjxeCLLs2fHZv8A4CQEGqCV7B7DkpUVGQA8fLjUs2dkXloqZWbGZv8A4CT8ygloJSeMYcnKktasid3+AMCp6KEBWokxLADgHAQaoJUYwwIAzkGgAVqJMSwA4ByMoQF+AMawAIAz0EMDAACMR6ABAADGI9AAAADjEWgAAIDxCDQAAMB4BBoAAGA8Ag2MZeebrgEAzsJzaGCkxjddN74csrxcKimJPOguK8vu6gAAsUYPDYxk95uuAQDOQqCBkZzwpmsAgHMQaGAk3nQNAPg2Ag1azc5BubzpGgDwbQQatErjoNw33pD++tfIPCcndqGGN10DAL6NXzmhVZoblOvxRNpj9fZp3nQN
AGhEDw1ahUG5AAAnIdCgVRiUCwBwEgINWoVBuQAAJ2lxoPn0009PRR0wDINyAQBO4rIsy2rJH/h8Pp177rkaO3asxo4dqz59+pyq2n40oVBIfr9f1dXVio+Pt7scAABwElpy/25xD81XX32l+++/X9u3b9eFF16oIUOGqLCwUPv27Wt1wQAAAD9Ei3tovq2+vl4rVqzQL37xC4VCIWVmZuqRRx5RpsO+d6CHBgAA87Tk/t2q59Ds3LlTq1at0ssvv6yPPvpIubm5Gjt2rP72t7/puuuu05dfftmqwgEAAFqjxYEmPT1dn3/+ua688kpNnz5dI0eOVIcOHSRJe/bsUVJS0o9eJAAAwN/T4kAzc+ZM/exnP1OnTp2O++yss87Sli1bfpTCAAAATlaLBwVPmDCh2TATC8FgUBkZGfL5fBoxYoQOHDhgSx0AAMBZjHmwXkNDg8aMGaO8vDzt3LlTPp9P06dPt7ssAADgAMa8nLK0tFSHDh3SnDlz5PV6lZ+fr+zsbB05ciQ6hgcAALRNxvTQBAIBDRkyRF5vJINlZGQoHA4rGAwet25tba1CoVCTCQAAnL6MCTTl5eVKTEyMLrvdbiUkJKiiouK4defPny+/3x+dUlNTY1kqAACIMWMCjSR99xmAlmXJ5XIdt96sWbNUXV0dnXiKMQAApzdjxtB0795dO3bsiC6Hw2FVVVUpOTn5uHXj4uIUFxcXy/IAAICNjOmhGTp0qDZt2qT6+npJ0ubNm+X1ejVo0CCbKwMAAHYzJtBkZ2crKSlJ+fn5KisrU0FBgUaPHq327dvbXRoAALCZMYHG7XarqKhIxcXFSktLU01NjQoLC+0uCwAAOIAxY2gkafDgwbxaAQAAHMeYHhoAAIATIdAAAADjEWgAAIDxCDQAAMB4BBoAAGA8Ag0AADAegQYAABiPQAMAAIxHoAEAAMYj0AAAAOMRaAAAgPEINAAAwHgEGgAAYDwCDQAAMB6BBgAAGI9AAwAAjEegAQAAxiPQAAAA4xFoAACA8Qg0AADAeAQaAABgPAINAAAwHoEGAAAYj0ADAACMR6ABAADGI9AAAADjEWgAAIDxCDQAAMB4BBoAAGA8Ag0AADAegQYAABiPQAMAAIxHoAEAAMYj0AAAAOMRaAAAgPEINAAAwHiOCjS7du3SsGHD1KlTJ+Xk5GjPnj3Rz4LBoDIyMuTz+TRixAgdOHDAxkoBAICTOCrQTJkyRb169dL27dvVtWtX3XnnnZKkhoYGjRkzRnl5edq5c6d8Pp+mT59uc7UAAMApXJZlWXYXIUlHjx7VGWecoe3bt2vAgAF67bXXdMMNN6i6ulpvvvmmRo0apYMHD8rr9SoYDCo7O1uVlZXq0KHD9247FArJ7/erurpa8fHxMTgaAADwQ7Xk/u2YHpq6ujotWLBAffr0kSQdPHhQPp9PkhQIBDRkyBB5vV5JUkZGhsLhsILBoG31AgAA5/DGeoc333yzVq9efVz7jBkzdP/990uKhJvHH39cN910kySpvLxciYmJ0XXdbrcSEhJUUVHR7D5qa2tVW1sbXQ6FQj/mIQAAAIeJeaBZtGiR5s6de1x7586dJUn19fUaP3683G63CgoKop9/95sxy7Lkcrma3cf8+fOb3QcAADg9xTzQJCUlKSkpqdnPGhoaNG7cOO3atUslJSXRr5y6d++uHTt2RNcLh8OqqqpScnJys9uZNWuW7rnnnuhyKBRSamrqj3gUAADASRwzhkaSCgoK9Nlnn2n9+vXq0qVLtH3o0KHatGmT6uvrJUmbN2+W1+vVoEGDmt1OXFyc4uPjm0wAAOD05ZhAU15ersWLF2v58uWSpKqqKlVVVamhoUHZ2dlKSkpSfn6+ysrKVFBQoNGjR6t9+/Y2Vw0AAJzAMYFm3bp1CoVCyszMVEJCQnTau3ev3G63ioqKVFxcrLS0NNXU1KiwsNDukgEAgEM45jk0pxLPoQEAwDxGPocGAACgtQg0AADAeAQaAABg
PAINAAAwHoEGAAAYj0ADAACMR6ABAADGI9AAAADjEWgAAIDxCDQAAMB4BBoAAGA8Ag0AADAegQYAABiPQAMAAIxHoAEAAMYj0AAAAOMRaAAAgPEINAAAwHgEGgAAYDwCDQAAMB6BBgAAGI9AAwAAjEegAQAAxiPQAAAA4xFoAACA8Qg0AADAeAQaAABgPAINAAAwHoEGAAAYj0ADAACMR6ABAADGI9AAAADjEWgAAIDxCDQAAMB4BBoAAGA8Ag0AADCeIwPNW2+9JZfLpQ0bNkTbgsGgMjIy5PP5NGLECB04cMC+AgEAgKM4LtDU1dXpjjvuaNLW0NCgMWPGKC8vTzt37pTP59P06dNtqhAAADiN1+4Cvmvx4sXq1q2bysrKom2lpaU6dOiQ5syZI6/Xq/z8fGVnZ+vIkSPq0KGDjdUCAAAncFQPTVlZmR5++GEtXbq0SXsgENCQIUPk9UbyV0ZGhsLhsILBoB1lAgAAh4l5D83NN9+s1atXH9c+Y8YMBYNBTZ06Vf3792/yWXl5uRITE6PLbrdbCQkJqqioaHYftbW1qq2tjS6HQqEfqXoAAOBEMQ80ixYt0ty5c49r/9///V899dRTeu6555r9O8uyjlt2uVzNrjt//vxm9wEAAE5PLuu7ScEmt9xyi1auXCmfzydJqq6uVocOHfTkk0/qiy++0Pr16/XnP/9ZkhQOh9W+fXutX79eWVlZx22ruR6a1NRUVVdXKz4+PjYHBAAAfpBQKCS/339S92/HBJqvvvpKX3/9dXT5ggsu0JNPPqm8vDwFg0Hl5eXp0KFD8nq9ev/99zVs2DBVVlaqffv237vtlpwQAADgDC25fztmUHBiYqJ69+4dndxut5KTk9WxY0dlZ2crKSlJ+fn5KisrU0FBgUaPHn1SYQYAAJz+HBNo/h63262ioiIVFxcrLS1NNTU1KiwstLssAADgEI57Dk2jqqqqJsuDBw/Wli1b7CkGAAA4mhE9NAAAAH8PgQYAABiPQAMAAIxHoAEAAMYj0AAAAOMRaAAAgPEINAAAwHgEGgAAYDwCDQAAMB6BBgAAGI9AY7JAQMrNlVJSIvNAwO6KAACwhWPf5YTvEQhIOTmSZUnhsFReLpWUSBs2SFlZdlcHAEBM0UNjqnnzjoUZKTK3rEg7AABtDIHGVNu2HQszjcLhSDsAAG0MgcZU6emSx9O0zeOJtAMA0MYQaEz1wAOSy3Us1Hg8keXZs+2tCwAAGxBoTJWVFRkAPHy41LNnZF5aKmVm2l0ZAAAxx6+cTJaVJa1ZY3cVAADYjh4aAABgPAINAAAwHoEGAAAYj0ADAACMR6ABAADGI9AAAADjEWgAAIDxCDQAAMB4BBoAAGA8Ag0AADAegQYAABiPQAMAAIxHoAEAAMYj0AAAAOMRaAAAgPEINAAAwHgEGgAAYDwCDQAAMB6BBgAAGM9Rgaaurk5Tp05Vp06dNGDAAL333nvRz4LBoDIyMuTz+TRixAgdOHDAxkoBAICTOCrQLFq0SLt379bmzZs1btw4TZgwQZLU0NCgMWPGKC8vTzt37pTP59P06dNtrhYAADiFy7Isy+4iGqWlpWnVqlUaOHCgvv76a61du1ajR49WaWmpRo0apYMHD8rr9SoYDCo7O1uVlZXq0KHD9243FArJ7/erurpa8fHxMTgSAADwQ7Xk/u2YHpry8nLt2rVLpaWl8vv9GjZsmAYOHCi3261AIKAhQ4bI6/VKkjIyMhQOhxUMBpvdVm1trUKhUJMJAACcvmIeaG6++WZ17tz5uOnXv/613G633n33XW3ZskX9+/fXtGnTJEXCTmJi4rGi3W4lJCSooqKi2X3Mnz9ffr8/OqWmpsbk2AAAgD28sd7hokWLNHfu3OPay8rKFA6HlZ+fr969e+uuu+5SZmamjh49Kkn67jdjlmXJ5XI1u49Zs2bpnnvuiS6HQiFCDQAAp7GYB5qkpCQlJSUd13748GFJUpcuXSRJXbt2lWVZOnTo
kLp3764dO3ZE1w2Hw6qqqlJycnKz+4iLi1NcXNwpqB4AADiRY8bQpKWlqV27dvr0008lSRUVFfJ4PEpMTNTQoUO1adMm1dfXS5I2b94sr9erQYMG2VkyAABwCMcEGp/Pp5EjR2rOnDn6/PPP9dhjjyk3N1der1fZ2dlKSkpSfn6+ysrKVFBQoNGjR6t9+/Z2lw0AABzAMYFGkpYtWybLspSenq6KigotWbJEUmQQcFFRkYqLi5WWlqaamhoVFhbaXC0AAHAKRz2H5lThOTQAAJjHyOfQAAAAtBaBBgAAGI9AAwAAjEegAQAAxiPQAAAA4xFoAACA8Qg0AADAeASaHyIQkHJzpZSUyDwQsLsiAADapJi/nPK0EQhIOTmSZUnhsFReLpWUSBs2SFlZdlcHAECbQg9Na82bdyzMSJG5ZUXaAQBATBFoWmvbtmNhplE4HGkHAAAxRaBprfR0yeNp2ubxRNoBAEBMEWha64EHJJfrWKjxeCLLs2fbWxcAAG0Qgaa1srIiA4CHD5d69ozMS0ulzEy7KwMAoM3hV04/RFaWtGaN3VUAANDm0UMDAACMR6ABAADGI9AAAADjEWgAAIDxCDQAAMB4BBoAAGA8Ag0AADAegQYAABiPQAMAAIxHoAEAAMYj0AAAAOO1iXc5WZYlSQqFQjZXAgAATlbjfbvxPv73tIlAc/jwYUlSamqqzZUAAICWOnz4sPx+/99dx2WdTOwxXENDg7788kt16tRJLpfrR912KBRSamqq9u3bp/j4+B912yZo68cvcQ44/rZ9/BLnoK0fv3TqzoFlWTp8+LB69Oght/vvj5JpEz00brdbKSkpp3Qf8fHxbfZfZInjlzgHHH/bPn6Jc9DWj186Nefg+3pmGjEoGAAAGI9AAwAAjEeg+YHi4uKUn5+vuLg4u0uxRVs/folzwPG37eOXOAdt/fglZ5yDNjEoGAAAnN7ooQEAAMYj0AAAAOMRaAAAgPEIND9AMBhURkaGfD6fRowYoQMHDthdUkzt2rVLw4YNU6dOnZSTk6M9e/bYXZJt3nrrLblcLm3YsMHuUmKqrq5OU6dOVadOnTRgwAC99957dpcUc9u2bdOll16qTp066corr9TevXvtLumU279/v4YNG6YPP/ww2tbWrofNnYO2dE1s7vgb2XU9JNC0UkNDg8aMGaO8vDzt3LlTPp9P06dPt7usmJoyZYp69eql7du3q2vXrrrzzjvtLskWdXV1uuOOO+wuwxaLFi3S7t27tXnzZo0bN04TJkywu6SYu/baa5WXl6dPPvlEvXv31uTJk+0u6ZSaOnWqevToobfeeiva1tauh82dA6ntXBNPdPySzddDC62yfv16Kz4+3qqrq7Msy7I++OADy+fzWV9//bXNlcVGbW2t5XK5rI8++siyLMv605/+ZMXHx9tclT0eeeQR67LLLrP8fr/15ptv2l1OTPXt29f68MMPLcuyrMOHD1tFRUVWOBy2uarYOXDggCXJ2r9/v2VZlvXOO+9Y7du3t7mqU6uystL64osvLEnW5s2bLctqe9fD5s5BW7omNnf8jey8HtJD00qBQEBDhgyR1xt5e0RGRobC4bCCwaDNlcVGXV2dFixYoD59+kiSDh48KJ/PZ3NVsVdWVqaHH35YS5cutbuUmCsvL9euXbtUWloqv9+vYcOGaeDAgd/7vpXTSUJCglJSUrRu3TpJ0tq1a5WRkWFvUadYYmKievfu3aStrV0PmzsHbema2NzxS/ZfD9vOledHVl5ersTExOiy2+1WQkKCKioqbKwqdjp06KB7771XPp9PdXV1evzxx3XTTTfZXVbM3X333Zo6dar69+9vdykxt3//frndbr377rvasmWL+vfvr2nTptldVkx5vV69+OKLmjp1quLi4rRkyRL99re/tbusmGvr10OJa6Jk//WQQPMDWN95JqFlWT/627ydrr6+XuPHj5fb7VZBQYHd5cTU2rVr9cEHH2j27Nl2l2KLI0eO
KBwOKz8/X71799Zdd92lN998U0ePHrW7tJj55ptvNHHiRM2dO1fBYFATJkw47cfQnAjXw4i2ek10wvWQQNNK3bt3V2VlZXQ5HA6rqqpKycnJNlYVWw0NDRo3bpw+++wzrVmz5rTtXj2RP/zhD9q/f7969Oihzp07q7q6Wnl5eVq5cqXdpcVE4xtwu3TpIknq2rWrLMvSoUOH7Cwrpl5//XUdPXpUM2fO1HnnnaeFCxfqnXfe0datW+0uLaa4Hka05WuiE66HBJpWGjp0qDZt2qT6+npJ0ubNm+X1ejVo0CCbK4udgoICffbZZ1q/fn30ptaWLFy4UJ988ok+/PBDffjhh+rUqZOefvppjRw50u7SYiItLU3t2rXTp59+KkmqqKiQx+Np8tXD6c7j8SgcDkeXLctSQ0NDdCxJW8H1MKItXxOdcD0k0LRSdna2kpKSlJ+fr7KyMhUUFGj06NFq37693aXFRHl5uRYvXqzly5dLkqqqqlRVVaWGhgabK4udxoFxjZPb7VZycrI6duxod2kx4fP5NHLkSM2ZM0eff/65HnvsMeXm5rapm/nFF1+sUCikxYsXq6ysTPfff79SU1PVr18/u0uLqbZ+PZS4JjrhekigaSW3262ioiIVFxcrLS1NNTU1KiwstLusmFm3bp1CoZAyMzOVkJAQndrCQ8VwzLJly2RZltLT01VRUaElS5bYXVJMJSUlqaioSM8884zOOeccvf3223r55ZfVrl07u0uLqbZ+PZS4JjoBb9sGAADGo4cGAAAYj0ADAACMR6ABAADGI9AAAADjEWgAAIDxCDQAAMB4BBoAAGA8Ag0AADAegQYAABiPQAPASAcPHlRiYqLWrl0rSVq8eLEGDRrU5GWRANoOXn0AwFhPPvmkli9frjfffFNnn322Xn31VWVmZtpdFgAbEGgAGKuhoUGXXHKJJOm8887TihUrbK4IgF34ygmAsdxut6ZMmaJNmzZp2rRpdpcDwEb00AAwVk1Njc4//3ylpaXJsiytW7fO7pIA2IQeGgDGevDBB9W3b1/98Y9/1K5du7Ry5Uq7SwJgE3poABjpL3/5i4YMGaL3339f5557rv70pz9p8uTJ+vjjj5WQkGB3eQBijEADAACMx1dOAADAeAQaAABgPAINAAAwHoEGAAAYj0ADAACMR6ABAADGI9AAAADjEWgAAIDxCDQAAMB4BBoAAGA8Ag0AADDe/wMNiPkpdYibwgAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from sklearn.mixture import GaussianMixture\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Work around \"UserWarning: Glyph 8722 (\\N{MINUS SIGN}) missing from current font\"\n",
    "plt.rc('axes', unicode_minus=False)\n",
    "\n",
    "# Observations as a column vector: sklearn expects shape (n_samples, n_features), and\n",
    "# reshape(-1, 1) lets NumPy infer the number of rows from the total element count.\n",
    "data = np.array([-67, -48, 6, 8, 14, 16, 23, 24, 28,\n",
    "                29, 41, 49, 56, 60, 75]).reshape(-1, 1)\n",
    "\n",
    "# Two-component Gaussian mixture. random_state fixes the randomised initialisation so\n",
    "# the estimates and the 0/1 label assignment are reproducible across runs.\n",
    "gmm_model = GaussianMixture(n_components=2, random_state=0)\n",
    "# Estimate the model parameters\n",
    "gmm_model.fit(data)\n",
    "# Assign every observation to its most likely component\n",
    "labels = gmm_model.predict(data)\n",
    "\n",
    "print(\"分类结果：labels = {}\\n\".format(labels))\n",
    "print(\"两个分量高斯混合模型的6个参数如下：\")\n",
    "# Means mu_1, mu_2\n",
    "print(\"means =\", gmm_model.means_.reshape(1, -1))\n",
    "# Variances sigma_1^2, sigma_2^2 (default covariance_type='full' gives shape (2, 1, 1) here)\n",
    "print(\"covariances =\", gmm_model.covariances_.reshape(1, -1))\n",
    "# Mixing weights alpha_1, alpha_2\n",
    "print(\"weights = \", gmm_model.weights_.reshape(1, -1))\n",
    "\n",
    "# Plot each observation against its index, coloured by its predicted component\n",
    "fig, ax = plt.subplots()\n",
    "indices = np.arange(len(labels))\n",
    "for component, color in ((0, 'red'), (1, 'blue')):\n",
    "    mask = labels == component\n",
    "    ax.scatter(indices[mask], data[mask, 0], s=15, c=color)\n",
    "ax.set_title('Gaussian Mixture Model')\n",
    "ax.set_xlabel('x')\n",
    "ax.set_ylabel('y')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68904cd9",
   "metadata": {},
   "source": [
    "### 自编程实现高斯混合模型的EM算法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8ad53159",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "alpha : [[0.56950675 0.43049325]]\n",
      "mean : [[27.41762854 12.35515017]]\n",
      "std : [[ 268.17311145 2772.33989897]]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import itertools\n",
    "\n",
    "\n",
    "class MyGMM:\n",
    "    \"\"\"One-dimensional Gaussian mixture model estimated with the EM algorithm.\n",
    "\n",
    "    Parameters are kept as (K, 1) column vectors so that they broadcast against\n",
    "    the (1, N) row vector of observations stored by fit().\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, alphas_init, means_init, covariances_init, tol=1e-6, n_components=2, max_iter=50,\n",
    "                 reg_covar=1e-6):\n",
    "        \"\"\"\n",
    "        :param alphas_init: initial mixing weights, one per component\n",
    "        :param means_init: initial component means, one per component\n",
    "        :param covariances_init: initial component variances (sigma^2), one per component\n",
    "        :param tol: stop once the Euclidean norm of the parameter change falls below this value\n",
    "        :param n_components: number of mixture components K\n",
    "        :param max_iter: maximum number of EM iterations\n",
    "        :param reg_covar: non-negative constant added to every variance estimate so that a component\n",
    "            collapsing onto a single observation cannot end up with zero variance\n",
    "        \"\"\"\n",
    "        # (1) Initial parameter values, stored as float64: float16 keeps only about three\n",
    "        # significant decimal digits, which visibly rounds the initial means and variances.\n",
    "        # Mixing weights alpha_k\n",
    "        self.alpha_ = np.array(alphas_init, dtype=float).reshape(n_components, 1)\n",
    "        # Component means mu_k\n",
    "        self.mean_ = np.array(means_init, dtype=float).reshape(n_components, 1)\n",
    "        # Component variances sigma_k^2 (squared standard deviations)\n",
    "        self.covariances_ = np.array(covariances_init, dtype=float).reshape(n_components, 1)\n",
    "        # Stopping threshold\n",
    "        self.tol = tol\n",
    "        # Number of mixture components\n",
    "        self.K = n_components\n",
    "        # Maximum number of iterations\n",
    "        self.max_iter = max_iter\n",
    "        # Variance floor added in every M-step\n",
    "        self.reg_covar = reg_covar\n",
    "        # Observations as a (1, N) row vector; set by fit()\n",
    "        self._y = None\n",
    "        # Number of EM iterations performed by the last call to fit()\n",
    "        self.n_iter_ = 0\n",
    "\n",
    "    def gaussian(self, mean, convariances):\n",
    "        \"\"\"Gaussian density of every observation under every component, shape (K, N).\"\"\"\n",
    "        return 1 / np.sqrt(2 * np.pi * convariances) * np.exp(\n",
    "            -(self._y - mean) ** 2 / (2 * convariances))\n",
    "\n",
    "    def update_r(self, mean, convariances, alpha):\n",
    "        \"\"\"E-step: responsibility r_jk of component k for observation y_j, shape (K, N).\"\"\"\n",
    "        r_jk = alpha * self.gaussian(mean, convariances)\n",
    "        # Normalise over the components so every column sums to 1\n",
    "        return r_jk / r_jk.sum(axis=0)\n",
    "\n",
    "    def update_params(self, r):\n",
    "        \"\"\"M-step: new means, variances and weights (each (K, 1)) from responsibilities r (K, N).\"\"\"\n",
    "        # Effective number of observations assigned to each component\n",
    "        n_k = r.sum(axis=1).reshape(self.K, 1)\n",
    "        _mean = (r * self._y).sum(axis=1).reshape(self.K, 1) / n_k\n",
    "        # Each component's variance is measured around its own, freshly updated mean\n",
    "        _covariances = (r * (self._y - _mean) ** 2).sum(axis=1).reshape(self.K, 1) / n_k \\\n",
    "            + self.reg_covar\n",
    "        _alpha = n_k / self._y.size\n",
    "        return _mean, _covariances, _alpha\n",
    "\n",
    "    def judge_stop(self, mean, covariances, alpha):\n",
    "        \"\"\"Return True when the new parameters differ from the stored ones by less than tol.\"\"\"\n",
    "        a = np.linalg.norm(self.mean_ - mean)\n",
    "        b = np.linalg.norm(self.covariances_ - covariances)\n",
    "        c = np.linalg.norm(self.alpha_ - alpha)\n",
    "        return bool(np.sqrt(a ** 2 + b ** 2 + c ** 2) < self.tol)\n",
    "\n",
    "    def fit(self, y):\n",
    "        \"\"\"Estimate the parameters from the observations y with EM.\n",
    "\n",
    "        Alternates E- and M-steps until the parameter change drops below tol or max_iter\n",
    "        iterations have run. Results are stored in alpha_, mean_, covariances_ and n_iter_.\n",
    "        :param y: array-like of N observations (flattened to a (1, N) row vector)\n",
    "        \"\"\"\n",
    "        self._y = np.array(y, dtype=float).reshape(1, -1)\n",
    "        self.n_iter_ = 0\n",
    "        for i in range(self.max_iter):\n",
    "            # (2) E-step: responsibility of each component k for each observation y_j\n",
    "            r = self.update_r(self.mean_, self.covariances_, self.alpha_)\n",
    "            # (3) M-step: new means, variances and weights\n",
    "            _mean, _covariances, _alpha = self.update_params(r)\n",
    "            # Measure the change against the current parameters before overwriting them\n",
    "            converged = self.judge_stop(_mean, _covariances, _alpha)\n",
    "            self.mean_ = _mean\n",
    "            self.covariances_ = _covariances\n",
    "            self.alpha_ = _alpha\n",
    "            self.n_iter_ = i + 1\n",
    "            # (4) Repeat until the change is below tol\n",
    "            if converged:\n",
    "                break\n",
    "\n",
    "    def score(self):\n",
    "        \"\"\"Log-likelihood of the observations under the current parameters,\n",
    "        sum_j log(sum_k alpha_k * phi(y_j | mu_k, sigma_k^2)); higher is better.\n",
    "        \"\"\"\n",
    "        density = (self.alpha_ * self.gaussian(self.mean_, self.covariances_)).sum(axis=0)\n",
    "        return np.log(density).sum()\n",
    "\n",
    "    \n",
    "# Observed data as a (1, N) row vector\n",
    "y = np.array([-67, -48, 6, 8, 14, 16, 23, 24, 28,\n",
    "             29, 41, 49, 56, 60, 75]).reshape(1, 15)\n",
    "# Rough overall mean and variance of the data; the grid of initial values is centred on them\n",
    "y_mean = y.mean() // 1\n",
    "y_var = (y.std() ** 2) // 1\n",
    "\n",
    "# Grid search: estimate the parameters from every combination of initial values\n",
    "alpha = [[i, 1 - i] for i in np.linspace(0.1, 0.9, 9)]\n",
    "mean = [[y_mean + i, y_mean + j]\n",
    "        for i in range(-10, 10, 5) for j in range(-10, 10, 5)]\n",
    "covariances = [[y_var + i, y_var + j]\n",
    "               for i in range(-1000, 1000, 500) for j in range(-1000, 1000, 500)]\n",
    "results = []\n",
    "# itertools.product yields the Cartesian product, i.e. every (alpha, mean, covariances) triple\n",
    "for init_alpha, init_mean, init_covariances in itertools.product(alpha, mean, covariances):\n",
    "    clf = MyGMM(alphas_init=init_alpha, means_init=init_mean, covariances_init=init_covariances,\n",
    "                n_components=2, tol=1e-6)\n",
    "    clf.fit(y)\n",
    "    score = clf.score()\n",
    "    # Skip degenerate runs (NaN/inf score) so they cannot be picked as the best solution\n",
    "    if np.isfinite(score):\n",
    "        # Local optimum reached from this initial value\n",
    "        results.append([clf.alpha_, clf.mean_, clf.covariances_, score])\n",
    "# Pick the local optimum with the highest score\n",
    "best_value = max(results, key=lambda x: x[3])\n",
    "\n",
    "print(\"alpha : {}\".format(best_value[0].T))\n",
    "print(\"mean : {}\".format(best_value[1].T))\n",
    "print(\"covariances : {}\".format(best_value[2].T))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
